{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The import cell uses the pre-1.0 AllenNLP API (e.g. `allennlp.data.iterators`),\n",
    "# so pin a release that still provides it rather than installing the latest one.\n",
    "# NOTE(review): the clone below is not pinned to a commit; confirm that its\n",
    "# `realworldnlp.predictors` module matches this AllenNLP release.\n",
    "%pip install allennlp==0.9.0\n",
    "!git clone https://github.com/mhagiwara/realworldnlp.git\n",
    "%cd realworldnlp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Dict\n",
    "\n",
    "import numpy as np\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from allennlp.data.dataset_readers.stanford_sentiment_tree_bank import \\\n",
    "    StanfordSentimentTreeBankDatasetReader\n",
    "from allennlp.data.iterators import BucketIterator\n",
    "from allennlp.data.token_indexers import SingleIdTokenIndexer\n",
    "from allennlp.data.vocabulary import Vocabulary\n",
    "from allennlp.models import Model\n",
    "from allennlp.modules.seq2vec_encoders import Seq2VecEncoder, CnnEncoder\n",
    "from allennlp.modules.text_field_embedders import TextFieldEmbedder, BasicTextFieldEmbedder\n",
    "from allennlp.modules.token_embedders import Embedding\n",
    "from allennlp.nn.util import get_text_field_mask\n",
    "from allennlp.training.metrics import CategoricalAccuracy, F1Measure\n",
    "from allennlp.training.trainer import Trainer\n",
    "\n",
    "from realworldnlp.predictors import SentenceClassifierPredictor"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Width of each token embedding; also passed to CnnEncoder as its input size.\n",
    "EMBEDDING_DIM = 128\n",
    "# NOTE(review): HIDDEN_DIM is not referenced by any other cell visible in this notebook.\n",
    "HIDDEN_DIM = 128"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "@Model.register(\"cnn_classifier\")\n",
    "class CnnClassifier(Model):\n",
    "    \"\"\"Sentence classifier: embed the tokens, encode them into one vector, project to label logits.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    embedder : TextFieldEmbedder\n",
    "        Applied to `tokens` to produce the encoder input.\n",
    "    encoder : Seq2VecEncoder\n",
    "        Called with the embeddings and the mask from `get_text_field_mask`;\n",
    "        its `get_output_dim()` sizes the input of the final linear layer.\n",
    "    vocab : Vocabulary\n",
    "        Its 'labels' namespace fixes the number of output classes and resolves\n",
    "        `positive_label` to an index.\n",
    "    positive_label : str\n",
    "        Label whose index is handed to `F1Measure`.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 embedder: TextFieldEmbedder,\n",
    "                 encoder: Seq2VecEncoder,\n",
    "                 vocab: Vocabulary,\n",
    "                 positive_label: str = '4') -> None:\n",
    "        super().__init__(vocab)\n",
    "        self.embedder = embedder\n",
    "        self.encoder = encoder\n",
    "\n",
    "        # One logit per entry in the 'labels' namespace.\n",
    "        self.linear = torch.nn.Linear(in_features=encoder.get_output_dim(),\n",
    "                                      out_features=vocab.get_vocab_size('labels'))\n",
    "\n",
    "        # Accuracy is tracked alongside precision/recall/F1 for `positive_label`.\n",
    "        positive_index = vocab.get_token_index(positive_label, namespace='labels')\n",
    "        self.accuracy = CategoricalAccuracy()\n",
    "        self.f1_measure = F1Measure(positive_index)\n",
    "\n",
    "        self.loss_function = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "    def forward(self,\n",
    "                tokens: Dict[str, torch.Tensor],\n",
    "                label: torch.Tensor = None) -> Dict[str, torch.Tensor]:\n",
    "        \"\"\"Compute label logits and, when `label` is given, the loss.\n",
    "\n",
    "        Returns a dict that always holds 'logits' and holds 'loss' only when\n",
    "        `label` is not None; in that case both metrics are updated as well.\n",
    "        \"\"\"\n",
    "        mask = get_text_field_mask(tokens)\n",
    "        embeddings = self.embedder(tokens)\n",
    "        encoder_out = self.encoder(embeddings, mask)\n",
    "        logits = self.linear(encoder_out)\n",
    "\n",
    "        output = {\"logits\": logits}\n",
    "        if label is not None:\n",
    "            self.accuracy(logits, label)\n",
    "            self.f1_measure(logits, label)\n",
    "            output[\"loss\"] = self.loss_function(logits, label)\n",
    "\n",
    "        return output\n",
    "\n",
    "    def get_metrics(self, reset: bool = False) -> Dict[str, float]:\n",
    "        \"\"\"Return accuracy, precision, recall and F1; `reset` is forwarded to each metric.\"\"\"\n",
    "        precision, recall, f1_measure = self.f1_measure.get_metric(reset)\n",
    "        return {'accuracy': self.accuracy.get_metric(reset),\n",
    "                'precision': precision,\n",
    "                'recall': recall,\n",
    "                'f1_measure': f1_measure}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# CnnEncoder's widest n-gram filter spans 5 tokens, so pad every sequence to at least 5.\n",
    "token_indexer = SingleIdTokenIndexer(token_min_padding_length=5)\n",
    "token_indexers = {'tokens': token_indexer}\n",
    "reader = StanfordSentimentTreeBankDatasetReader(token_indexers=token_indexers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Both splits live under the same remote directory.\n",
    "SST_TREES_URL = 'https://s3.amazonaws.com/realworldnlpbook/data/stanfordSentimentTreebank/trees'\n",
    "train_dataset = reader.read(f'{SST_TREES_URL}/train.txt')\n",
    "dev_dataset = reader.read(f'{SST_TREES_URL}/dev.txt')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The vocabulary is built from train + dev, with a minimum count of 3 for the 'tokens' namespace.\n",
    "token_min_count = {'tokens': 3}\n",
    "vocab = Vocabulary.from_instances(instances=train_dataset + dev_dataset,\n",
    "                                  min_count=token_min_count)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# One embedding row per entry in the 'tokens' namespace of the vocabulary.\n",
    "vocab_size = vocab.get_vocab_size('tokens')\n",
    "token_embedding = Embedding(num_embeddings=vocab_size, embedding_dim=EMBEDDING_DIM)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The key matches the 'tokens' indexer name given to the dataset reader.\n",
    "token_embedders = {\"tokens\": token_embedding}\n",
    "word_embeddings = BasicTextFieldEmbedder(token_embedders)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Convolution widths cover 2-grams through 5-grams, with num_filters=8.\n",
    "ngram_filter_sizes = (2, 3, 4, 5)\n",
    "encoder = CnnEncoder(embedding_dim=EMBEDDING_DIM,\n",
    "                     num_filters=8,\n",
    "                     ngram_filter_sizes=ngram_filter_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# positive_label is left at the classifier's default.\n",
    "model = CnnClassifier(embedder=word_embeddings,\n",
    "                      encoder=encoder,\n",
    "                      vocab=vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Batches of 32, sorted by the token count of the 'tokens' field.\n",
    "sorting_keys = [(\"tokens\", \"num_tokens\")]\n",
    "iterator = BucketIterator(batch_size=32, sorting_keys=sorting_keys)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hand the vocabulary to the iterator before it is used by the trainer.\n",
    "iterator.index_with(vocab=vocab)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Adam over every model parameter.\n",
    "learning_rate = 1e-3\n",
    "optimizer = optim.Adam(model.parameters(), lr=learning_rate)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Train on the train split and validate on the dev split, for at most 20 epochs (patience=10).\n",
    "trainer = Trainer(\n",
    "    model=model,\n",
    "    optimizer=optimizer,\n",
    "    iterator=iterator,\n",
    "    train_dataset=train_dataset,\n",
    "    validation_dataset=dev_dataset,\n",
    "    num_epochs=20,\n",
    "    patience=10,\n",
    ")\n",
    "# Left as the last expression so the value returned by train() is shown as the cell output.\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Classify one sentence and map the highest-scoring logit back to its label string.\n",
    "predictor = SentenceClassifierPredictor(model, dataset_reader=reader)\n",
    "prediction = predictor.predict('This is the best movie ever!')\n",
    "logits = prediction['logits']\n",
    "label_id = np.argmax(logits)\n",
    "\n",
    "predicted_label = model.vocab.get_token_from_index(label_id, 'labels')\n",
    "print(predicted_label)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
